# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for mcore dopt_attn UT of inference."""
import numpy as np

# Dimensions of the random inputs built by get_init_params(): query is
# (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE).
BATCH_SIZE = 2
SEQ_LENGTH = 2
HIDDEN_SIZE = 32
# Divisor used by get_init_params() when sizing the last axis of key/value:
# int(num_kv_heads * HIDDEN_SIZE / NUM_HEADS).
NUM_HEADS = 2


def get_init_params(num_kv_heads):
    """Build random query/key/value arrays for the test.

    Args:
        num_kv_heads: Scales the last axis of key and value (presumably the
            number of key/value heads -- confirm against the caller).

    Returns:
        Dict with keys "query", "key" and "value". Every array is drawn from
        a normal distribution with mean 0 and standard deviation 0.01 using
        NumPy's global random state. "query" has shape
        (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE); "key" and "value" have shape
        (BATCH_SIZE, SEQ_LENGTH, int(num_kv_heads * HIDDEN_SIZE / NUM_HEADS)).
    """
    kv_hidden = int(num_kv_heads * HIDDEN_SIZE / NUM_HEADS)
    leading_dims = (BATCH_SIZE, SEQ_LENGTH)
    shapes = {
        "query": leading_dims + (HIDDEN_SIZE,),
        "key": leading_dims + (kv_hidden,),
        "value": leading_dims + (kv_hidden,),
    }
    # Dicts keep insertion order, so the global RNG is consumed in the order
    # query -> key -> value.
    return {
        name: np.random.normal(0, 0.01, shape)
        for name, shape in shapes.items()
    }


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded float32 golden outputs for the test.

    Returns:
        Dict with keys "output_1" and "output_2"; each value is a freshly
        built float32 array of shape (2, 2, 32) holding the literals below.
    """
    # NOTE(review): the (2, 2, 32) shape matches
    # (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE); presumably that is the axis
    # order -- confirm against the test that consumes this data.
    return {
        "output_1": np.array(
            [
                [
                    [-9.06825624e-03, -1.84027303e-03, 7.27286236e-03,
                     -8.50748096e-04, 6.42714184e-03, -1.82695605e-03,
                     2.04526610e-03, -8.61833710e-03, -7.73729756e-03,
                     -4.42506233e-03, 6.58267923e-03, -2.88367388e-03,
                     -6.97717257e-03, 6.79270737e-03, -1.02970153e-02,
                     7.22385375e-05, -9.07058176e-03, -1.84019836e-03,
                     7.27265561e-03, -8.52211553e-04, 6.42751623e-03,
                     -1.82832603e-03, 2.04393314e-03, -8.61854479e-03,
                     -7.73755554e-03, -4.42678388e-03, 6.58102008e-03,
                     -2.88395165e-03, -6.97663939e-03, 6.79184915e-03,
                     -1.02981320e-02, 7.28907762e-05],
                    [-9.06925369e-03, -1.84024067e-03, 7.27277342e-03,
                     -8.51375924e-04, 6.42730203e-03, -1.82754383e-03,
                     2.04469403e-03, -8.61842558e-03, -7.73740793e-03,
                     -4.42580087e-03, 6.58196723e-03, -2.88379285e-03,
                     -6.97694346e-03, 6.79233903e-03, -1.02974940e-02,
                     7.25184218e-05, -9.07090306e-03, -1.84018747e-03,
                     7.27262627e-03, -8.52413534e-04, 6.42756699e-03,
                     -1.82851497e-03, 2.04374897e-03, -8.61857273e-03,
                     -7.73759047e-03, -4.42702137e-03, 6.58079050e-03,
                     -2.88398983e-03, -6.97656488e-03, 6.79172995e-03,
                     -1.02982847e-02, 7.29808453e-05],
                ],
                [
                    [5.18906908e-03, 2.90517928e-03, 5.38176624e-03,
                     -1.54857070e-03, 1.42390896e-02, -2.23570131e-03,
                     8.43618298e-04, -6.26045698e-03, 4.06868663e-03,
                     5.60942292e-03, -7.04066537e-04, 5.08407457e-03,
                     1.74883977e-02, 1.10973176e-02, -1.44029444e-03,
                     7.42228236e-03, 5.18907094e-03, 2.90588895e-03,
                     5.38180210e-03, -1.54859561e-03, 1.42386407e-02,
                     -2.23740982e-03, 8.42688081e-04, -6.25937060e-03,
                     4.06930409e-03, 5.60929440e-03, -7.05415267e-04,
                     5.08490531e-03, 1.74880922e-02, 1.10982480e-02,
                     -1.43961632e-03, 7.42423115e-03],
                    [5.18906955e-03, 2.90554087e-03, 5.38178394e-03,
                     -1.54858339e-03, 1.42388614e-02, -2.23657233e-03,
                     8.43144080e-04, -6.25990285e-03, 4.06900095e-03,
                     5.60935773e-03, -7.04753795e-04, 5.08449785e-03,
                     1.74882431e-02, 1.10977916e-02, -1.43994880e-03,
                     7.42327515e-03, 5.18906955e-03, 2.90529383e-03,
                     5.38177229e-03, -1.54857477e-03, 1.42390179e-02,
                     -2.23597721e-03, 8.43468239e-04, -6.26028236e-03,
                     4.06878628e-03, 5.60940243e-03, -7.04284175e-04,
                     5.08420914e-03, 1.74883492e-02, 1.10974675e-02,
                     -1.44018512e-03, 7.42259668e-03],
                ],
            ],
            dtype=np.float32,
        ),
        "output_2": np.array(
            [
                [
                    [6.3402881e-03, -3.3660554e-03, -1.4444257e-03, -8.0461344e-03,
                     -2.5268497e-03, 7.1405637e-04, -5.4758051e-03, 2.3154295e-03,
                     -7.7178353e-03, -4.8816805e-03, -1.5195193e-02, -3.3191510e-03,
                     -2.5233140e-03, -8.2748593e-04, -2.6979309e-03, -1.4895399e-03,
                     9.5586979e-04, 1.7731448e-04, 1.9400379e-03, 3.9372500e-03,
                     -9.0299202e-03, 4.4953106e-03, -9.0429299e-03, -5.2149305e-03,
                     -6.9167255e-03, 1.5271897e-03, -1.5173258e-02, -5.5669569e-03,
                     -6.3685403e-04, 1.2024220e-03, 1.7976601e-02, -1.9472408e-03],
                    [6.3402713e-03, -3.3667018e-03, -1.4446254e-03, -8.0462266e-03,
                     -2.5268241e-03, 7.1404869e-04, -5.4759672e-03, 2.3156214e-03,
                     -7.7165477e-03, -4.8830407e-03, -1.5195013e-02, -3.3196299e-03,
                     -2.5236532e-03, -8.2746305e-04, -2.6976173e-03, -1.4904591e-03,
                     9.5564674e-04, 1.7694275e-04, 1.9404497e-03, 3.9369222e-03,
                     -9.0298494e-03, 4.4953972e-03, -9.0428125e-03, -5.2155368e-03,
                     -6.9166031e-03, 1.5267493e-03, -1.5173142e-02, -5.5665937e-03,
                     -6.3676835e-04, 1.2024306e-03, 1.7976811e-02, -1.9473977e-03],
                ],
                [
                    [-1.3177248e-02, -8.7196082e-03, -4.4969837e-03, 8.8495445e-03,
                     -6.6079047e-05, 4.5926729e-03, 1.4041479e-04, -3.4923179e-03,
                     5.2190269e-03, -1.3577295e-02, 9.7465888e-03, -1.3444002e-03,
                     1.2624572e-03, -5.3344131e-03, -4.3867803e-03, -5.6758854e-03,
                     1.8999334e-03, -9.9857384e-03, -6.0378467e-03, -1.1245237e-02,
                     -2.6462374e-03, -1.7408989e-04, -1.3619080e-02, -8.8263461e-03,
                     -7.1176738e-03, 6.9640032e-03, 1.8937998e-03, -1.6155215e-03,
                     -7.9378895e-03, 2.9506660e-03, 3.6616573e-05, -2.2505156e-03],
                    [-1.3176448e-02, -8.7197758e-03, -4.4968980e-03, 8.8499207e-03,
                     -6.6713350e-05, 4.5928918e-03, 1.4191275e-04, -3.4922399e-03,
                     5.2196230e-03, -1.3577253e-02, 9.7467788e-03, -1.3441592e-03,
                     1.2632419e-03, -5.3346725e-03, -4.3863538e-03, -5.6762993e-03,
                     1.8989743e-03, -9.9870786e-03, -6.0368893e-03, -1.1244749e-02,
                     -2.6470430e-03, -1.7253826e-04, -1.3620944e-02, -8.8264663e-03,
                     -7.1180952e-03, 6.9634058e-03, 1.8955427e-03, -1.6156896e-03,
                     -7.9373559e-03, 2.9521899e-03, 3.7390335e-05, -2.2501319e-03],
                ],
            ],
            dtype=np.float32,
        ),
    }


def get_gpu_data() -> dict[str, np.ndarray]:
    """Return the hard-coded float16 GPU reference outputs for the test.

    The stored values contain exact repetition, so each distinct run of
    numbers is written once and the full nested lists are assembled from it.

    Returns:
        Dict with keys "output_1" and "output_2"; each value is a freshly
        built float16 array of shape (2, 2, 32).
    """
    # "output_1": every 32-value row is the same 16 values twice, and both
    # rows of an outer entry are equal.
    o1_half_first = [
        -9.0942e-03, -1.8463e-03, 7.2937e-03, -8.5449e-04,
        6.4392e-03, -1.8158e-03, 2.0447e-03, -8.6060e-03,
        -7.7515e-03, -4.4250e-03, 6.5918e-03, -2.8839e-03,
        -6.9885e-03, 6.7749e-03, -1.0254e-02, 7.6294e-05,
    ]
    o1_half_second = [
        5.1880e-03, 2.9144e-03, 5.4016e-03, -1.5488e-03,
        1.4221e-02, -2.2583e-03, 8.3923e-04, -6.2561e-03,
        4.0588e-03, 5.6152e-03, -7.0190e-04, 5.0964e-03,
        1.7456e-02, 1.1108e-02, -1.4343e-03, 7.4463e-03,
    ]
    o1_rows = [o1_half_first * 2, o1_half_second * 2]

    # "output_2": both rows of an outer entry are equal.
    o2_rows = [
        [
            6.348e-03, -3.372e-03, -1.450e-03, -8.057e-03, -2.533e-03,
            7.133e-04, -5.493e-03, 2.319e-03, -7.690e-03, -4.913e-03,
            -1.526e-02, -3.326e-03, -2.518e-03, -8.278e-04, -2.686e-03,
            -1.495e-03, 9.613e-04, 1.831e-04, 1.938e-03, 3.937e-03,
            -9.033e-03, 4.486e-03, -9.033e-03, -5.219e-03, -6.927e-03,
            1.511e-03, -1.520e-02, -5.554e-03, -6.371e-04, 1.205e-03,
            1.794e-02, -1.953e-03,
        ],
        [
            -1.318e-02, -8.728e-03, -4.517e-03, 8.850e-03, -9.155e-05,
            4.608e-03, 1.221e-04, -3.494e-03, 5.219e-03, -1.361e-02,
            9.766e-03, -1.350e-03, 1.282e-03, -5.341e-03, -4.395e-03,
            -5.676e-03, 1.900e-03, -1.001e-02, -6.042e-03, -1.129e-02,
            -2.640e-03, -1.831e-04, -1.367e-02, -8.850e-03, -7.111e-03,
            6.958e-03, 1.877e-03, -1.617e-03, -7.935e-03, 2.945e-03,
            3.052e-05, -2.243e-03,
        ],
    ]

    return {
        "output_1": np.array(
            [[row, row] for row in o1_rows], dtype=np.float16
        ),
        "output_2": np.array(
            [[row, row] for row in o2_rows], dtype=np.float16
        ),
    }


# Evaluated once at import time: the dicts returned by get_golden() and
# get_gpu_data(), exposed as module-level constants.
GOLDEN_DATA = get_golden()
GPU_DATA = get_gpu_data()
